package com.chatplus.application.web.auditlogger;

import com.baomidou.mybatisplus.core.toolkit.Wrappers;
import com.chatplus.application.common.constant.RedisPrefix;
import com.chatplus.application.dao.framework.AuditLogRuleDao;
import com.chatplus.application.domain.entity.framework.AuditLogRuleEntity;
import com.chatplus.application.enumeration.AuditLogRuleTypeEnum;
import com.chatplus.application.web.auditlogger.properties.AuditLoggerProperties;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import jakarta.annotation.PostConstruct;
import jakarta.servlet.http.HttpServletRequest;
import org.apache.commons.collections4.CollectionUtils;
import org.jetbrains.annotations.NotNull;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Component;
import org.springframework.util.ObjectUtils;

import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * 审计日志黑白名单工具
 *
 * @author chj
 * @date 2022/7/18
 **/
/**
 * Audit-log whitelist/blacklist utility.
 *
 * <p>Decides, per servlet path, whether a request should be <em>excluded</em> from audit logging.
 * Rules are loaded from {@link AuditLogRuleDao}; each rule URL is trimmed and compiled as a
 * regular expression anchored at the start of the path ({@code "^" + url}). Evaluation:
 * <ul>
 *   <li>a blacklist match always excludes the path;</li>
 *   <li>if a whitelist exists, paths matching none of its entries are excluded;</li>
 *   <li>if there is no whitelist, every non-blacklisted path is audited.</li>
 * </ul>
 *
 * <p>The compiled rules are published as one immutable snapshot through a {@code volatile}
 * field, so concurrent requests never observe a half-rebuilt rule set.
 *
 * @author chj
 * @date 2022/7/18
 **/
@Component
public class UrlFilterUtil {

    /** Lifetime of the Redis rule-refresh marker, in minutes. */
    private static final int REDIS_RULE_TIME = 5;

    /** Redis key whose absence triggers a reload of the in-memory rules. */
    private static final String RULE_REFRESH_KEY = RedisPrefix.TURN_RIGHT_REDIS_PREFIX + "#AuditLoggerInterceptor";

    private final StringRedisTemplate redisTemplate;
    private final AuditLogRuleDao auditLogRuleDao;
    private final AuditLoggerProperties auditLoggerProperties;

    /** Current compiled rules; replaced wholesale on reload, never mutated in place. */
    private volatile RuleSnapshot rules = RuleSnapshot.EMPTY;

    /**
     * Per-path decision cache. Entries expire 3 minutes after being computed, so a rule change
     * can take up to that long to affect a path that is already cached.
     */
    private volatile LoadingCache<String, Boolean> servletPathExcludesCache;

    public UrlFilterUtil(StringRedisTemplate redisTemplate
            , AuditLogRuleDao auditLogRuleDao
            , AuditLoggerProperties auditLoggerProperties) {
        this.redisTemplate = redisTemplate;
        this.auditLogRuleDao = auditLogRuleDao;
        this.auditLoggerProperties = auditLoggerProperties;
    }

    /**
     * Loads the rules and builds the per-path decision cache.
     * Runs once after construction; also invoked lazily if the cache is somehow still missing.
     */
    @PostConstruct
    private synchronized void initial() {
        if (servletPathExcludesCache != null) {
            // Already initialised (e.g. by a concurrent lazy caller); don't rebuild the cache.
            return;
        }
        // Load rules before publishing the cache so the first lookups already see them.
        initRuleCache();
        servletPathExcludesCache = CacheBuilder.newBuilder()
            .maximumSize(2000)
            .expireAfterWrite(3, TimeUnit.MINUTES)
            .build(new CacheLoader<>() {
                @NotNull
                @Override
                public Boolean load(@NotNull String servletPath) {
                    return excludesURL(servletPath);
                }
            });
    }

    /**
     * Tells whether the given request should be skipped by the audit logger.
     *
     * @param request the incoming request; its servlet path is the lookup key
     * @return {@code true} if the request must NOT be audited (always {@code true} when audit
     *         logging is disabled), {@code false} if it should be audited
     */
    public boolean excludesURL(HttpServletRequest request) {
        if (!auditLoggerProperties.isEnable()) {
            return true;
        }
        String servletPath = request.getServletPath();
        try {
            LoadingCache<String, Boolean> cache = servletPathExcludesCache;
            if (cache == null) {
                initial();
                cache = servletPathExcludesCache;
            }
            return cache.get(servletPath);
        } catch (Exception e) {
            // Cache lookup/loading failed: evaluate the rules directly instead.
            return excludesURL(servletPath);
        }
    }

    /**
     * Reloads all rules from the database and atomically publishes them.
     * Does nothing when audit logging is disabled. If compiling a rule throws, the previously
     * published rules stay in effect.
     */
    private void initRuleCache() {
        if (!auditLoggerProperties.isEnable()) {
            return;
        }
        List<AuditLogRuleEntity> auditLogRuleList = auditLogRuleDao.selectList(Wrappers.emptyWrapper());
        Set<Pattern> whiteList = new HashSet<>();
        Set<Pattern> blackList = new HashSet<>();
        for (AuditLogRuleEntity item : auditLogRuleList) {
            String ruleUrl = item.getUrl();
            if (ruleUrl == null) {
                continue;
            }
            String trimmed = ruleUrl.trim();
            // A blank rule would compile to "^" and match every path; ignore it.
            if (trimmed.isEmpty()) {
                continue;
            }
            if (item.getType() == AuditLogRuleTypeEnum.WHITE_LIST) {
                whiteList.add(Pattern.compile("^" + trimmed));
            } else if (item.getType() == AuditLogRuleTypeEnum.BLACK_LIST) {
                blackList.add(Pattern.compile("^" + trimmed));
            }
        }
        // Single volatile write: readers see either the old or the new rule set, never a mix.
        rules = new RuleSnapshot(whiteList, blackList);
    }

    /**
     * Reloads the rules if the Redis refresh marker is absent, re-arming the marker atomically.
     * NOTE(review): the marker key is a fixed, shared Redis key while the rules live in this
     * instance's memory, so with several instances only the one that wins {@code setIfAbsent}
     * reloads in a given window — confirm this is intended.
     */
    private void refreshRulesIfDue() {
        Boolean armed = redisTemplate.boundValueOps(RULE_REFRESH_KEY)
            .setIfAbsent("1", REDIS_RULE_TIME, TimeUnit.MINUTES);
        if (Boolean.TRUE.equals(armed)) {
            initRuleCache();
        }
    }

    /**
     * Evaluates the rules for one servlet path (uncached).
     *
     * @param servletPath the path to test
     * @return {@code true} if the path must be excluded from audit logging
     */
    private boolean excludesURL(String servletPath) {
        refreshRulesIfDue();
        // Read the volatile once so both checks use the same rule set.
        RuleSnapshot snapshot = rules;
        // The blacklist takes priority over the whitelist.
        if (matchesAny(snapshot.excludes(), servletPath)) {
            return true;
        }
        // An empty whitelist means "audit everything that is not blacklisted".
        return !snapshot.includes().isEmpty() && !matchesAny(snapshot.includes(), servletPath);
    }

    /** Returns whether any pattern finds a match in {@code path}. */
    private static boolean matchesAny(Set<Pattern> patterns, String path) {
        for (Pattern pattern : patterns) {
            if (pattern.matcher(path).find()) {
                return true;
            }
        }
        return false;
    }

    /** Immutable pair of compiled whitelist ({@code includes}) and blacklist ({@code excludes}) rules. */
    private record RuleSnapshot(Set<Pattern> includes, Set<Pattern> excludes) {
        static final RuleSnapshot EMPTY = new RuleSnapshot(Set.of(), Set.of());

        RuleSnapshot {
            includes = Set.copyOf(includes);
            excludes = Set.copyOf(excludes);
        }
    }
}
